open Base
module Path = Draw_tree.Path
module Update = Draw_tree.Update

(* Identifiers for mounted sinks. Ids are plain ints drawn from a global,
   monotonically increasing counter, printed as "<sink:N>" by [to_string]. *)
module Sink_id : sig
  include Identifiable.S
  val gen : unit -> t
  val anticipate : int -> t
end = struct
  include Int

  let to_string id = Printf.sprintf "<sink:%d>" id

  (* The id that the next call to [gen] will hand out. *)
  let counter = ref 0

  (* Returns a fresh id and advances the counter. *)
  let gen () =
    let fresh = !counter in
    Int.incr counter;
    fresh

  (* [anticipate k] is the id that [gen] will return after [k] further
     calls, without consuming anything ([anticipate 0] is the next id). *)
  let anticipate k = !counter + k
end

(*** components ***)

(* Description of a piece of the draw tree, prior to mounting.
   - [Tree (children, make_tree)]: a static node; when mounted, [make_tree]
     receives the draw trees of [children] in list order and builds this
     node's tree.
   - [Dynamic b]: a component obtained by sampling behavior [b]; mounting it
     creates a sink that is re-sampled and re-mounted when one of [b]'s
     dependency sources changes.
   - [With_fold (init, step, make)]: mounting allocates a fresh source
     holding [init] and mounts the component [make src].
     NOTE(review): [step] looks intended to fold events into that source's
     value, but the mounting code in this file ignores it — confirm. *)
type component =
  | Tree of component list * (Draw_tree.t list -> Draw_tree.t)
  | Dynamic of component Behavior.t
  | With_fold of Source.value
               * (Event.t -> Source.value -> Source.value)
               * (Source.t -> component)

(* sinks and connections *)

(* The mounted instance of a [Dynamic] component. *)
type sink =
  { behavior : component Behavior.t
    (* sampled again whenever the sink has to be rebuilt *)
  ; sub_sinks : Sink_id.t list
    (* ids of the top-level sinks mounted inside this sink's subtree; they
       (and, transitively, their own sub-sinks) are removed before the sink
       is re-mounted *)
  }

(* Every live sink, paired with the draw-tree path computed for its tree
   when it was mounted. *)
type sinks = (Sink_id.t, Path.t * sink, Sink_id.comparator_witness) Map.t

(* For each source, the ids of the sinks whose behavior depends on it. *)
type connections = (Source.t, Sink_id.t list, Source.comparator_witness) Map.t

(* Event handlers. NOTE(review): the visible code only ever initialises this
   to [Event.empty_map]; nothing registers handlers yet. *)
type handlers = Event.handler Event.map

(* The complete runtime state threaded through mounting and updates. *)
type state =
  { sources : Source.State.t
  ; sinks : sinks
  ; conns : connections
  ; handlers : handlers
  }

(* Starting points for the sink and connection maps. *)
let empty_sinks = Map.empty (module Sink_id)
let empty_conns = Source.empty_map

(* Union of two sink maps; on a key collision the entry from [s2] wins. *)
let merge_sinks s1 s2 =
  Map.merge_skewed s1 s2 ~combine:(fun ~key:_ _older newer -> newer)

(* Union of two connection maps; on a key collision the id lists are
   concatenated, [c1]'s ids first. *)
let merge_conns c1 c2 =
  Map.merge_skewed c1 c2 ~combine:(fun ~key:_ first second -> first @ second)

(* A state holding [sources] with no sinks, connections or handlers. *)
let empty_state sources =
  { sources; sinks = empty_sinks; conns = empty_conns; handlers = Event.empty_map }

(* Records [sink_id] under every source that [bhv] depends on, so a change
   to any of those sources can be routed back to the sink. *)
let connect_behavior bhv sink_id conns =
  let register acc src = Map.add_multi acc ~key:src ~data:sink_id in
  Sequence.fold (Behavior.dependencies bhv) ~init:conns ~f:register

(* Removes the sink [sink_id] from [st]: drops its entry from [st.sinks] and
   deletes every occurrence of its id from [st.conns].

   Sources whose id list becomes empty are removed from [st.conns] entirely.
   Previously the empty lists were kept, so entries for sources nobody
   listens to any more accumulated on every re-mount. Lookups are
   unaffected: [Map.find_multi] on a missing key also yields [], and
   [Map.add_multi]/[Map.merge_skewed] treat a missing key like []. *)
let remove_sink_by_id st sink_id =
  let without_sink ids =
    match List.filter ids ~f:(fun id -> not (Sink_id.equal id sink_id)) with
    | [] -> None
    | remaining -> Some remaining
  in
  { st with
    sinks = Map.remove st.sinks sink_id
  ; conns = Map.filter_map st.conns ~f:without_sink }

(* All sink ids nested under [sink] in depth-first pre-order: each direct
   sub-sink is followed by its own descendants. Ids that are no longer in
   [st.sinks] are skipped together with whatever lies beneath them. *)
let rec recursive_subsink_ids sink st =
  let expand child_id =
    Map.find st.sinks child_id
    |> Option.value_map ~default:[] ~f:(fun (_path, child) ->
         child_id :: recursive_subsink_ids child st)
  in
  List.concat_map sink.sub_sinks ~f:expand

(* mounting, re-mounting of sinks, and propagation of source changes *)

(* Instantiates [component] against the source state [sources].

   Returns the updated source state, the sinks and connections created while
   instantiating (the caller merges them into a [state]), the ids of the
   top-level sinks of [component] (those not nested inside another
   [Dynamic]), and the initial draw tree.

   A [Dynamic]'s sink id is generated before its children are instantiated,
   so an ancestor sink always has a smaller id than any of its descendants.
   [source_changed] relies on this ordering. *)
let instantiate component sources
  : Source.State.t * sinks * connections * Sink_id.t list * Draw_tree.t =
  let sources = ref sources in
  let sinks = ref empty_sinks in
  let conns = ref empty_conns in
  let rec inst = function
    | Tree (coms, make_tree) ->
      (* [fold_map] visits children left to right, which matters because
         [inst] mutates the refs above and generates sink ids. Per-child
         sink lists are collected and concatenated once, instead of with
         repeated [@]. The result order (last child first) is unchanged. *)
      let sink_lists, subtrees =
        List.fold_map coms ~init:[] ~f:(fun acc com ->
          let child_sinks, tree = inst com in
          (child_sinks :: acc, tree))
      in
      (List.concat sink_lists, make_tree subtrees)
    | Dynamic b ->
      let sink_id = Sink_id.gen () in
      (* create the tree that will be initially visible *)
      let init_com = Behavior.sample b !sources in
      let sub_sinks, init_tree = inst init_com in
      (* create a new sink and associate the behavior to it *)
      let sink = { behavior = b; sub_sinks } in
      let path = Draw_tree.path init_tree in
      let () = sinks := Map.add_exn !sinks ~key:sink_id ~data:(path, sink) in
      let () = conns := connect_behavior b sink_id !conns in
      ([ sink_id ], init_tree)
    | With_fold (init, f, make_component) ->
      (* create source with initial value *)
      let src = Source.create () in
      let () = sources := Source.State.set src init !sources in
      (* NOTE(review): the step function [f] is not registered anywhere yet,
         so events never update [src]. *)
      ignore (f) [@ocaml.warning "-5"];
      inst (make_component src)
  in
  let top_sinks, tree = inst component in
  (!sources, !sinks, !conns, top_sinks, tree)

(* Mounts [component] into [st]. Also returns the ids of the component's
   top-level sinks, so a re-mounted parent can record them as its
   [sub_sinks]. *)
let mount_with_sinks component st : state * Sink_id.t list * Draw_tree.t =
  let sources, sinks, conns, top_sinks, tree = instantiate component st.sources in
  ( { st with
      sources
    ; sinks = merge_sinks st.sinks sinks
    ; conns = merge_conns st.conns conns }
  , top_sinks
  , tree )

(* Mounts [component] into [st], returning the new state and the component's
   initial draw tree. *)
let mount component st : state * Draw_tree.t =
  let st, _top_sinks, tree = mount_with_sinks component st in
  (st, tree)

(* Removes every sink nested under [sink], samples [sink]'s behavior again
   and mounts the resulting component.

   Returns the new state, the ids of the freshly mounted top-level sinks
   (the sink's new [sub_sinks]), and the updates to apply at the sink's
   path. The stored record of [sink] itself is not modified. *)
let remount sink st =
  let st =
    List.fold (recursive_subsink_ids sink st) ~init:st ~f:remove_sink_by_id
  in
  let component' = Behavior.sample sink.behavior st.sources in
  let st, sub_sinks, tree = mount_with_sinks component' st in
  (st, sub_sinks, [ Update.Unmount; Update.Mount tree ])

(* Re-mounts [sink] (see [remount]) and returns the resulting updates.

   This function does not record the newly mounted child sinks in
   [st.sinks], because it does not know [sink]'s id. [source_changed] uses
   [remount] directly and writes the new [sub_sinks] back itself. *)
let update_sink sink st =
  let st, _new_sub_sinks, upds = remount sink st in
  (st, upds)

(* Sets [src] to [value] and re-mounts every sink that depends on [src].
   Returns the new state and the (path, update) pairs for the draw tree. *)
let source_changed src value st : state * Draw_tree.updates =
  let st = { st with sources = Source.State.set src value st.sources } in
  (* Ancestors have smaller ids than their descendants, so ascending order
     re-mounts a parent before any of its children. Children torn down by
     that parent are then gone from [st.sinks] and are skipped below,
     instead of being re-mounted into a detached tree (which leaked sinks).
     Deduplicating avoids re-mounting a sink twice when its behavior lists
     [src] more than once. *)
  let sink_ids =
    List.dedup_and_sort (Map.find_multi st.conns src) ~compare:Sink_id.compare
  in
  let st, upds =
    List.fold sink_ids ~init:(st, []) ~f:(fun (st, upds) sink_id ->
      match Map.find st.sinks sink_id with
      | None -> (st, upds)
      | Some (path, sink) ->
        let st, sub_sinks, new_upds = remount sink st in
        (* Record the new children so the next re-mount removes them. *)
        let st =
          { st with
            sinks = Map.set st.sinks ~key:sink_id ~data:(path, { sink with sub_sinks }) }
        in
        let new_upds = List.map new_upds ~f:(fun u -> (path, u)) in
        (st, List.append new_upds upds))
  in
  (st, Sequence.of_list upds)

(* Mounts [component] into a fresh state built from [sources], then
   refreshes the paths stored in the resulting draw tree before returning
   both the state and the tree. *)
let init sources component : state * Draw_tree.t =
  let ((_, tree) as mounted) = mount component (empty_state sources) in
  Draw_tree.refresh_paths tree;
  mounted

(*** component combinators ***)

(* A static node whose children [coms] are laid out by [Draw_tree.arm dir]. *)
let arm dir coms = Tree (coms, Draw_tree.arm dir)

(* A constant leaf drawing the figure [fig]. *)
let cfigure fig = Tree ([], fun _ -> Draw_tree.figure fig)

(* A constant rectangle of dimensions [dim] and colour [c]. *)
let crect dim c = cfigure (Draw_tree.Rect (dim, c))

(* A leaf whose figure follows behavior [b]. *)
let figure b = Dynamic (Behavior.map b ~f:cfigure)

(* A rectangle whose (dimensions, colour) pair follows behavior [b]. *)
let rect b =
  let to_rect (dim, c) = Draw_tree.Rect (dim, c) in
  figure (Behavior.( >>| ) b to_rect)

(* A stateful component: a fresh source starts at [init] (with [f] as its
   intended step function), and [com] is given a behavior reading that
   source. *)
let fold ~init ~f com =
  let make_component src = com (Behavior.of_source src) in
  With_fold (init, f, make_component)
